# Rekall Memory Forensics
# Copyright (c) 2010, 2011, 2012 Michael Ligh <michael.ligh@mnin.org>
# Copyright 2013 Google Inc. All Rights Reserved.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or (at
# your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
#
from rekall.plugins.windows import common
from rekall.plugins.windows import filescan
from rekall_lib import utils


class DeviceTree(common.PoolScannerPlugin):
    """Show device tree.

    Pool-scans for driver objects and, for each driver, lists the devices on
    its NextDevice chain followed by the devices attached on top of each of
    those devices.
    """

    __name = "devicetree"

    # Every row yielded by collect() carries a "depth" key; presumably the
    # TreeNode "Type" column uses it for indentation (max_depth=10).
    table_header = [
        dict(name="Type", type="TreeNode", width=10, max_depth=10),
        dict(name="Address", style="address", padding="0"),
        dict(name="Name", width=30),
        dict(name="device_type", width=30),
        dict(name="Path"),
    ]

    # Default scan target: the kernel non-paged pool.
    scanner_defaults = dict(
        scan_kernel_nonpaged_pool=True
    )

    def generate_hits(self):
        """Yield driver objects found by pool-scanning each memory range.

        Freed pool allocations are included (freed=True).
        """
        for mem_range in self.generate_memory_ranges():
            pool_scanner = filescan.PoolScanDriver(
                session=self.session, profile=self.profile,
                address_space=mem_range.address_space)

            hits = pool_scanner.scan(mem_range.start, mem_range.length)
            for pool_hit in hits:
                for obj_header in pool_hit.IterObject("Driver", freed=True):
                    yield obj_header.Object

    def collect(self):
        """Yield one DRV row per driver, then its DEV and ATT rows.

        DEV rows (depth 1) come from walking the driver's NextDevice list.
        Each DEV row is followed by ATT rows for the AttachedDevice chain,
        with the depth increasing by one per attached device, starting at 2.
        """
        for driver in self.generate_hits():
            driver_name = driver.DriverName.v(vm=self.kernel_address_space)
            yield dict(
                Type=utils.AttributedString("DRV", [(0, 30, "BLACK", "RED")]),
                Address=driver.obj_offset,
                Name=driver_name,
                depth=0)

            head_device = driver.DeviceObject.dereference(
                vm=self.kernel_address_space)

            for dev in head_device.walk_list("NextDevice"):
                # The object header sits immediately before the object body,
                # so step back by the offset of the header's Body member.
                body_offset = dev.obj_profile.get_obj_offset(
                    "_OBJECT_HEADER", "Body")
                header = self.profile.Object(
                    "_OBJECT_HEADER", offset=dev.obj_offset - body_offset,
                    vm=dev.obj_vm)

                dev_name = header.NameInfo.Name.cast(
                    vm=self.kernel_address_space)

                yield dict(
                    Type=utils.AttributedString(
                        "DEV", [(0, 30, "WHITE", "BLUE")]),
                    Address=dev.obj_offset,
                    Name=dev_name,
                    device_type=dev.DeviceType,
                    depth=1)

                # Attached devices nest one level deeper each, starting at 2.
                # ATT rows reuse the name of the device they are stacked on.
                attached = dev.walk_list(
                    "AttachedDevice", include_current=False)
                for att_depth, att in enumerate(attached, start=2):
                    yield dict(
                        Type=utils.AttributedString(
                            "ATT", [(0, 30, "BLACK", "GREEN")]),
                        Address=att.obj_offset,
                        Name=dev_name,
                        device_type=att.DeviceType,
                        Path=att.DriverObject.DriverName,
                        depth=att_depth)


class DriverIrp(common.PoolScannerPlugin):
    """Driver IRP hook detection

    For each driver object, lists the dispatch routines in its
    MajorFunction table. Routines that point outside the driver's own image
    (DriverStart .. DriverStart + DriverSize) are flagged.
    """

    __name = "driverirp"

    # NOTE(review): not referenced within this class; kept for compatibility.
    mod_re = None

    __args = [
        dict(name="regex", type="RegEx",
             help='Analyze drivers matching REGEX'),
    ]

    # "func_addres" (sic) is the row key emitted by collect(); renaming it
    # would change the plugin's output schema.
    table_header = [
        dict(name="divider", type="Divider"),
        dict(name="driver", hidden=True),
        dict(name="idx", width=4, align="r"),
        dict(name="function", width=36),
        dict(name="func_addres", style="address"),
        dict(name="name")
    ]

    def generate_hits(self):
        """Yield candidate driver objects.

        Without an explicit scan specification, drivers are enumerated
        through the object_tree plugin. Otherwise the requested memory
        ranges are pool-scanned for Driver objects, including freed ones.
        """
        if not self.scan_specification_requested():
            obj_tree_plugin = self.session.plugins.object_tree(
                type_regex="Driver")
            for hit in obj_tree_plugin.collect():
                yield hit["_OBJECT_HEADER"].Object

            return

        for run in self.generate_memory_ranges():
            scanner = filescan.PoolScanDriver(
                session=self.session, profile=self.profile,
                address_space=run.address_space)

            for pool_obj in scanner.scan(run.start, run.length):
                for object_obj in pool_obj.IterObject("Driver", freed=True):
                    yield object_obj.Object

    def collect(self):
        """Yield a divider row per reported driver, then its IRP rows.

        A driver is skipped when a regex was supplied and its name is empty
        or does not match. Below verbosity 2, a driver is also skipped when
        none of its IRP handlers point outside its own image. Handlers that
        point at nt!IopInvalidDeviceRequest are hidden below verbosity 5.
        """
        resolver = self.session.address_resolver
        invalid_address = resolver.get_constant_object(
            "nt!IopInvalidDeviceRequest", "Function").obj_offset
        regex = self.plugin_args.regex
        verbosity = self.plugin_args.verbosity

        for driver_obj in self.generate_hits():
            driver_name = driver_obj.DriverName.v(vm=self.kernel_address_space)

            # Skip drivers that do not match a supplied regex. A nameless
            # driver can never match.
            if regex and not (driver_name and regex.search(driver_name)):
                continue

            driver_start = driver_obj.DriverStart.v()
            # The image occupies [driver_start, driver_end): driver_end is
            # the first address past the driver, not part of it.
            driver_end = driver_start + driver_obj.DriverSize

            interesting = False
            functions = []
            # Record the address and owner of each IRP function.
            for i, function_ptr in enumerate(driver_obj.MajorFunction):
                # Make sure this is in the kernel address space.
                function = function_ptr.dereference(
                    vm=self.kernel_address_space)

                func_addres = function.obj_offset
                # Deliberately `==` rather than `is`: a failed dereference
                # presumably yields Rekall's NoneObject placeholder, which is
                # assumed to compare equal to None but is not the None
                # singleton. Confirm before changing.
                if func_addres == None:  # noqa: E711
                    continue

                # Suppress function pointers which point at the default
                # invalid-request handler.
                if verbosity < 5 and func_addres == invalid_address:
                    continue

                highlight = None

                # Functions residing within the driver are not suspicious.
                if not (driver_start <= func_addres < driver_end):
                    interesting = True
                    # Extra important if we have no idea where it came from.
                    if not resolver.format_address(func_addres):
                        highlight = "important"

                functions.append(dict(
                    driver=driver_obj,
                    idx=i,
                    function=function.obj_name,
                    func_addres=func_addres,
                    name=utils.FormattedAddress(resolver, func_addres),
                    highlight=highlight))

            # By default only show interesting drivers.
            if verbosity < 2 and not interesting:
                continue

            # Write the standard header for each driver object.
            divider = "DriverName: %s %#x-%#x" % (
                driver_name, driver_start, driver_end)

            yield dict(divider=divider)

            for x in functions:
                yield x
